package xin.nick.rpc.provider.handler;

import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.cglib.reflect.FastClass;
import org.springframework.cglib.reflect.FastMethod;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Service;
import org.springframework.util.ClassUtils;
import xin.nick.rpc.common.RpcRequest;
import xin.nick.rpc.common.RpcResponse;
import xin.nick.rpc.provider.anno.RpcService;

import java.lang.reflect.InvocationTargetException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * @author Nick
 * @date 2021/7/4
 * @description 服务端业务处理类
 */
@Slf4j
@Service
@ChannelHandler.Sharable
public class RpcServerHandler extends SimpleChannelInboundHandler<String> implements ApplicationContextAware {
    /**
     * 将标有@Rpcservice注解的bean缓存
     * 接收客户端请求
     * 根据传递过来的bean从缓存中查找相对应的bean
     * 解析请求中的方法名称,参数类型,参数信息
     * 反射调用bean方法
     * 返回响应数据
     */

    private static final Map SERVICE_INSTANCE_MAP = new ConcurrentHashMap();


    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        super.handlerAdded(ctx);
        InetSocketAddress ipSocket = (InetSocketAddress)ctx.channel().remoteAddress();
        String clientIp = ipSocket.getAddress().getHostAddress();
        int clientPort = ipSocket.getPort();
        log.info("---> 新客户端: {} : {}" ,clientIp, clientPort);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        super.handlerRemoved(ctx);
        InetSocketAddress ipSocket = (InetSocketAddress)ctx.channel().remoteAddress();
        String clientIp = ipSocket.getAddress().getHostAddress();
        int clientPort = ipSocket.getPort();
        log.info("---> 客户端断开: {} : {}" ,clientIp, clientPort);
    }

    /**
     * 通道读取就绪事件
     * @param channelHandlerContext
     * @param msg
     * @throws Exception
     */
    @Override
    protected void channelRead0(ChannelHandlerContext channelHandlerContext, String msg) throws Exception {

        // 接收请求,将msg转化为RpcRequest对象
        RpcRequest rpcRequest = JSON.parseObject(msg, RpcRequest.class);
        RpcResponse rpcResponse = new RpcResponse();
        rpcResponse.setRequestId(rpcRequest.getRequestId());

        try {
            rpcResponse.setResult(handler(rpcRequest));
        } catch (Exception e) {
            e.printStackTrace();
            rpcResponse.setError(e.getMessage());
        }

        // 响应
        channelHandlerContext.writeAndFlush(JSON.toJSONString(rpcResponse));
    }

    /**
     * 业务逻辑处理
     * @param rpcRequest
     * @return
     */
    private Object handler(RpcRequest rpcRequest) throws InvocationTargetException {
        // 找到相对应的bean
        Object serviceBean = SERVICE_INSTANCE_MAP.get(rpcRequest.getClassName());
        if(Objects.isNull(serviceBean)) {
            throw new RuntimeException("根据beanName找不到服务,beanName:" + rpcRequest.getClassName());
        }

        // 解析方法名,参数类型,参数信息
        Class<?> serviceBeanClass = serviceBean.getClass();
        String methodName = rpcRequest.getMethodName();
        Class<?>[] parameterTypes = rpcRequest.getParameterTypes();
        Object[] parameters = rpcRequest.getParameters();

        // 反射调用bean方法-CGLIB
        FastClass fastClass = FastClass.create(serviceBeanClass);
        FastMethod method = fastClass.getMethod(methodName, parameterTypes);
        return method.invoke(serviceBean, parameters);

    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        /**
         * 在这里缓存
         */
        Map<String, Object> serviceMap = applicationContext.getBeansWithAnnotation(RpcService.class);


        if (serviceMap != null && serviceMap.size() > 0) {
            Set<Map.Entry<String, Object>> entrySet = serviceMap.entrySet();
            for (Map.Entry<String, Object> item : entrySet) {
                Object serviceBean = item.getValue();
                if (serviceBean.getClass().getInterfaces().length == 0) {
                    throw new RuntimeException("服务必须实现接口");
                }
                // 默认接收一个接口作为存储bean的名称
                String name = serviceBean.getClass().getInterfaces()[0].getName();
                SERVICE_INSTANCE_MAP.put(name, serviceBean);
            }
        }

    }


}

